import torch.nn as nn
import torch.nn.functional as F

class Sim_AE(nn.Module):
    """Convolutional autoencoder with five encoder and five decoder stages.

    Every convolution uses a 3x3 kernel, stride 1 and padding 1, so the spatial
    size (H, W) is preserved through the whole network. Only the channel count
    changes: in_channels -> 16 -> 32 -> 64 -> 128 -> 256, then mirrored back.

    NOTE(review): there is no pooling or striding. The "encoded" tensor
    (256 channels at full resolution) is therefore larger than the input, not
    a compressed bottleneck. Confirm this is intended.

    Args:
        in_channels: number of channels of the input image. The reconstruction
            has the same number of channels. Defaults to 1, the original
            single-channel configuration.
    """

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__()
        # Encoder: Conv -> BatchNorm -> ReLU, doubling the channels each stage.
        # Construction order matches the original layout, so parameter
        # initialisation consumes the RNG in the same sequence.
        self.enc_cnn_1 = self._encoder_stage(in_channels, 16)
        self.enc_cnn_2 = self._encoder_stage(16, 32)
        self.enc_cnn_3 = self._encoder_stage(32, 64)
        self.enc_cnn_4 = self._encoder_stage(64, 128)
        self.enc_cnn_5 = self._encoder_stage(128, 256)
        # Decoder: transposed convolutions mirroring the encoder, with no
        # BatchNorm. The ReLU between stages is applied in forward().
        self.denc_cnn_5 = self._decoder_stage(256, 128)
        self.denc_cnn_4 = self._decoder_stage(128, 64)
        self.denc_cnn_3 = self._decoder_stage(64, 32)
        self.denc_cnn_2 = self._decoder_stage(32, 16)
        self.denc_cnn_1 = self._decoder_stage(16, in_channels)

    @staticmethod
    def _encoder_stage(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build one size-preserving Conv2d -> BatchNorm2d -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _decoder_stage(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build one size-preserving ConvTranspose2d stage.

        The layer is wrapped in a one-element ``nn.Sequential``. This keeps
        parameter names in the ``denc_cnn_k.0.*`` form of the original layout.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: input batch. Assumed (N, in_channels, H, W); TODO confirm
                against callers.

        Returns:
            A tuple ``(encoded_img, out)``:

            - ``encoded_img`` is the 256-channel output of the last encoder
              stage.
            - ``out`` is the reconstruction, with ``in_channels`` channels and
              the same H, W as ``x``.

            A ReLU follows every decoder stage, including the last, so ``out``
            is non-negative.
        """
        out = x
        for stage in (self.enc_cnn_1, self.enc_cnn_2, self.enc_cnn_3,
                      self.enc_cnn_4, self.enc_cnn_5):
            out = stage(out)
        encoded_img = out

        # The decoder uses no skip connections from the encoder.
        # NOTE(review): the final ReLU constrains reconstructions to [0, inf).
        # Confirm this matches the target range.
        for stage in (self.denc_cnn_5, self.denc_cnn_4, self.denc_cnn_3,
                      self.denc_cnn_2, self.denc_cnn_1):
            out = F.relu(stage(out))
        return encoded_img, out
